/*
 * SPDX-FileCopyrightText: 2025 vLLM Team
 * SPDX-FileCopyrightText: 2025 Qingcheng.AI
 *
 * SPDX-License-Identifier: Apache-2.0
 */

#ifndef _data_types_cuh
#define _data_types_cuh
#include "marlin.cuh"
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif

namespace MARLIN_NAMESPACE_NAME {

// Primary ScalarType template: intentionally empty. The per-type aliases and
// conversion helpers live in the explicit specializations for the supported
// element types.
template <typename T>
class ScalarType {};

// ScalarType specialization for fp16 (`half`): element/pair aliases, fragment
// aliases, and thin wrappers around the CUDA half conversion intrinsics.
template <> class ScalarType<half> {
  public:
    using scalar_t = half;   // single element
    using scalar_t2 = half2; // packed pair of elements

    // Fragment types for tensor core mma instructions. The exact per-thread
    // layout is described in the PTX ISA:
    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
    using FragA = Vec<scalar_t2, 4>;
    using FragB = Vec<scalar_t2, 2>;
    using FragC = Vec<float, 4>;
    using FragS = Vec<scalar_t2, 1>;
    using FragZP = Vec<scalar_t2, 4>;

    // half -> float, via __half2float.
    static inline __device__ float num2float(const scalar_t x) {
        return __half2float(x);
    }

    // Builds a half2 from the single value `x`, via __half2half2.
    static inline __device__ scalar_t2 num2num2(const scalar_t x) {
        return __half2half2(x);
    }

    // Packs the two values `x1` and `x2` into one half2, via __halves2half2.
    static inline __device__ scalar_t2 nums2num2(const scalar_t x1,
                                                 const scalar_t x2) {
        return __halves2half2(x1, x2);
    }

    // float -> half, via __float2half. Unlike the helpers above, this one is
    // declared for both host and device.
    static inline __host__ __device__ scalar_t float2num(const float x) {
        return __float2half(x);
    }
};

// ScalarType specialization for bf16 (`nv_bfloat16`): element/pair aliases,
// fragment aliases, and thin wrappers around the CUDA bfloat16 conversion
// intrinsics.
template <> class ScalarType<nv_bfloat16> {
  public:
    using scalar_t = nv_bfloat16;   // single element
    using scalar_t2 = nv_bfloat162; // packed pair of elements

    // Fragment aliases; same vector widths as the fp16 specialization, with
    // nv_bfloat162 pairs in place of half2 (FragC stays float).
    using FragA = Vec<scalar_t2, 4>;
    using FragB = Vec<scalar_t2, 2>;
    using FragC = Vec<float, 4>;
    using FragS = Vec<scalar_t2, 1>;
    using FragZP = Vec<scalar_t2, 4>;

// The conversion helpers are only declared in the host compilation pass or
// when the device target is compute capability 8.0 or newer.
// NOTE(review): presumably because the bf16 intrinsics used here need SM80+
// -- confirm.
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
    // nv_bfloat16 -> float, via __bfloat162float.
    static inline __device__ float num2float(const scalar_t x) {
        return __bfloat162float(x);
    }

    // Builds an nv_bfloat162 from the single value `x`, via
    // __bfloat162bfloat162.
    static inline __device__ scalar_t2 num2num2(const scalar_t x) {
        return __bfloat162bfloat162(x);
    }

    // Packs the two values `x1` and `x2` into one nv_bfloat162, via
    // __halves2bfloat162.
    static inline __device__ scalar_t2 nums2num2(const scalar_t x1,
                                                 const scalar_t x2) {
        return __halves2bfloat162(x1, x2);
    }

    // float -> nv_bfloat16, via __float2bfloat16. Unlike the helpers above,
    // this one is declared for both host and device.
    static inline __host__ __device__ scalar_t float2num(const float x) {
        return __float2bfloat16(x);
    }
#endif
};

} // namespace MARLIN_NAMESPACE_NAME

#endif
